# 2.5 - 3 hour run time
from datasets import load_dataset, DownloadConfig
import ahocorasick
import pandas as pd
import warnings
import os
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
pd.set_option('future.no_silent_downcasting', True)
pd.options.mode.chained_assignment = None 

# Pol_NLI source datasets to leave out of the contamination check: their
# text was generated after DeBERTa was trained.
excluded_sources = [
    "mlburnham/anthropic_persuasion",
    "mlburnham/argument_quality_ranking_entailment",
    "mlburnham/hatespeech_entailment",
    "mlburnham/violent_hatespeech_entailment",
    "mlburnham/dehumanizing_hatespeech_entailment",
    "mlburnham/targeted_hatespeech_entailment",
    "mlburnham/ibm_claimstance_entailment",
    "mlburnham/ibm_claimstance_topic_entailment",
    "mlburnham/polarizing_rhetoric",
]

# Load every split of Pol_NLI and keep only the rows whose `dataset` field
# is not one of the excluded sources.
pol = load_dataset('mlburnham/Pol_NLI', split='all').filter(
    lambda example: example['dataset'] not in excluded_sources
)

# Index the Pol_NLI premises in an Aho-Corasick automaton so that each corpus
# document can be checked against all of them in a single pass.
A = ahocorasick.Automaton()
for row_id, example in enumerate(pol):
    # Normalise the premise (lower-case, surrounding whitespace removed); the
    # whole premise is the search key (it could instead be split into shorter
    # phrases).
    needle = example['premise'].lower().strip()
    # The payload carries the row position in `pol` alongside the key itself.
    # NOTE(review): rows with the same normalised premise share one key, so
    # presumably only one row index per distinct premise is kept -- confirm
    # against pyahocorasick's add_word semantics.
    A.add_word(needle, (row_id, needle))
# Finalise the trie; required before the automaton can be searched.
A.make_automaton()

###################
# Scan OpenWebText
###################
# The corpus is opened in streaming mode and iterated record by record.
owt_stream = load_dataset('Skylion007/openwebtext', split='train', streaming=True, trust_remote_code=True)

# One (Pol_NLI row index, corpus name, record id) tuple per premise occurrence
# found in the lower-cased document text. The end offset and the matched
# pattern reported by the automaton are not needed. A document can contribute
# several tuples.
matches = [
    (pol_idx, 'OpenWebText', record.get('id'))
    for record in owt_stream
    for _end, (pol_idx, _pat) in A.iter(record['text'].lower())
]

print(f"Found {len(matches)} substring matches in OpenWebText.")

# Materialise the filtered Pol_NLI rows once; later sections look rows up in
# this frame as well.
docs = pol.to_pandas()
# Pull out the row indices and select each matched row a single time.
match_nums = [hit[0] for hit in matches]
unique_rows = list(set(match_nums))
matchdf = docs.loc[unique_rows, :]

###################
# Scan CC News
###################
# Iterate the streamed corpus and collect, for every premise occurrence in a
# lower-cased article, a (Pol_NLI row index, corpus name, record id) tuple.
cc_stream = load_dataset('vblagoje/cc_news', split='train', streaming=True, trust_remote_code=True)
ccmatches = []
for article in cc_stream:
    lowered = article['text'].lower()
    # Only the stored row index is used; the end offset and matched pattern
    # yielded by the automaton are ignored.
    ccmatches.extend(
        (pol_idx, 'CC News', article.get('id'))
        for _end, (pol_idx, _pat) in A.iter(lowered)
    )

print(f"Found {len(ccmatches)} substring matches in CC News.")
# De-duplicate the row indices, then select the matching Pol_NLI rows.
ccmatch_nums = [hit[0] for hit in ccmatches]
cc_unique_rows = list(set(ccmatch_nums))
ccmatchdf = docs.loc[cc_unique_rows, :]

###################
# Scan STORIES
###################
# Stream the STORIES corpus and record one
# (Pol_NLI row index, corpus name, record id) tuple for every premise
# occurrence found in a lower-cased document.
stmatches = []
for record in load_dataset('lucadiliello/STORIES', split = 'train', streaming=True, trust_remote_code=True):
    doc = record['text'].lower()
    # Only the stored Pol_NLI row index is needed; the end offset and the
    # matched pattern yielded by the automaton are ignored.
    for _end_index, (pol_idx, _pat) in A.iter(doc):
        stmatches.append((pol_idx, 'STORIES', record.get('id')))
        # (optional) break if you only need one match per doc

# Report against the corpus that was actually scanned (the message previously
# named OpenWebText by mistake).
print(f"Found {len(stmatches)} substring matches in STORIES.")
# De-duplicate the row indices, then select the matching Pol_NLI rows.
# Assumes `docs` still has the default positional index, so the enumerate()
# positions stored in the automaton are valid labels -- TODO confirm.
stmatch_nums = [match[0] for match in stmatches]
stmatchdf = docs.loc[list(set(stmatch_nums)),]

###################
# Compile Results
###################
# Stack the matched rows from the three corpora and keep one row per distinct
# premise, so a premise found in several corpora (or paired with several
# hypotheses) is counted once.
per_corpus_frames = [matchdf, ccmatchdf, stmatchdf]
all_matches = pd.concat(per_corpus_frames).drop_duplicates(subset='premise')
print(f"Found {all_matches.shape[0]} substring matches across all datasets.")